{
 "metadata": {
  "name": ""
 },
 "nbformat": 3,
 "nbformat_minor": 0,
 "worksheets": [
  {
   "cells": [
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "%matplotlib inline"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 1
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import os\n",
      "import re\n",
      "import bz2\n",
      "import numpy as np\n",
      "from os.path import expanduser, join, exists\n",
      "import matplotlib.pyplot as plt\n",
      "\n",
      "DATA_FOLDER = expanduser('~/data/mnist8m')\n",
      "MNIST8M_SRC_URL = ('http://www.csie.ntu.edu.tw/~cjlin/libsvmtools/'\n",
      "                   'datasets/multiclass/mnist8m.bz2')\n",
      "MNIST8M_SRC_FILENAME = MNIST8M_SRC_URL.rsplit('/', 1)[1]\n",
      "MNIST8M_SRC_FILEPATH = join(DATA_FOLDER, MNIST8M_SRC_FILENAME)\n",
      "MNIST8M_MMAP_DATA_FILENAME = \"mnist8m_data.mmap\"\n",
      "MNIST8M_MMAP_DATA_FILEPATH = join(DATA_FOLDER, MNIST8M_MMAP_DATA_FILENAME)\n",
      "\n",
      "MNIST8M_MMAP_LABELS_FILENAME = \"mnist8m_labels.mmap\"\n",
      "MNIST8M_MMAP_LABELS_FILEPATH = join(DATA_FOLDER, MNIST8M_MMAP_LABELS_FILENAME)\n",
      "\n",
      "CHUNK_FILENAME_PREFIX = \"mnist8m-chunk-\"\n",
      "\n",
      "CHUNK_SIZE = 100000"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 2
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Download the `mnist8m.bz2` source file into the data folder if not previously downloaded:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "if not exists(DATA_FOLDER):\n",
      "    os.makedirs(DATA_FOLDER)\n",
      "\n",
      "if not exists(MNIST8M_SRC_FILEPATH):\n",
      "    cmd = \"(cd '%s' && wget -c '%s')\" % (DATA_FOLDER, MNIST8M_SRC_URL)\n",
      "    print(cmd)\n",
      "    os.system(cmd)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 2
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Decompress and chunk the source svmlight formatted data file to make it easier to process it in parallel:"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "chunk_filenames = [fn for fn in os.listdir(DATA_FOLDER)\n",
      "                   if fn.startswith(CHUNK_FILENAME_PREFIX)]\n",
      "chunk_filenames.sort()\n",
      "\n",
      "\n",
      "if not chunk_filenames:\n",
      "    chunk_filenames = []\n",
      "    with bz2.BZ2File(MNIST8M_SRC_FILEPATH) as source:\n",
      "        target, line_no, chunk_idx = None, 0, 0\n",
      "        for line in source:\n",
      "            line_no += 1\n",
      "            if target is None:\n",
      "                chunk_filename = \"%s%03d.svmlight\" % (\n",
      "                    CHUNK_FILENAME_PREFIX, chunk_idx)\n",
      "                chunk_filename = join(DATA_FOLDER, chunk_filename)\n",
      "                target = open(chunk_filename, 'wb')\n",
      "                chunk_idx += 1\n",
      "                chunk_filenames.append(chunk_filename)\n",
      "                \n",
      "            target.write(line)\n",
      "                \n",
      "            if line_no >= CHUNK_SIZE:\n",
      "                target.close()\n",
      "                target, line_no = None, 0\n",
      "        if target is not None:\n",
      "            target.close()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Memory mapping utility to load or create the mmap data files"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def get_mmap_data(data_filepath, labels_filepath, mode=None, order='C'):\n",
      "    import numpy as np\n",
      "    from os.path import exists\n",
      "\n",
      "    n_samples, n_features = 8100000, 28 ** 2\n",
      "    shape = (n_samples, n_features)\n",
      "    \n",
      "    # Reopen the existing data map if it already exists, otherwise create it\n",
      "    if mode is None:\n",
      "        mode_data = 'r+' if exists(data_filepath) else 'w+'\n",
      "    else:\n",
      "        mode_data = mode\n",
      "    data = np.memmap(data_filepath, shape=shape,\n",
      "                     dtype=np.float32, mode=mode_data, order=order)\n",
      "    \n",
      "    if mode is None:\n",
      "        mode_labels = 'r+' if exists(labels_filepath) else 'w+'\n",
      "    else:\n",
      "        mode_labels = mode\n",
      "    labels = np.memmap(labels_filepath, shape=n_samples,\n",
      "                       dtype=np.int32, mode=mode_labels)\n",
      "    return data, labels"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 2
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Trigger the creation of the mmap files with zero data if missing.\n",
      "data, labels = get_mmap_data(MNIST8M_MMAP_DATA_FILEPATH,\n",
      "                             MNIST8M_MMAP_LABELS_FILEPATH)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Initialization of the IPython.parallel cluster client"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from IPython.parallel import Client"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 6
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "client = Client()\n",
      "lb_view = client.load_balanced_view()\n",
      "len(lb_view)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 7,
       "text": [
        "17"
       ]
      }
     ],
     "prompt_number": 7
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Ship the definition of the mmap loading helper function to the\n",
      "# worker processes\n",
      "client[:][get_mmap_data.__name__] = get_mmap_data"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 8
    },
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Assigning engines to fixed CPUs"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from collections import defaultdict\n",
      "\n",
      "def collect_host_infos(direct_view):\n",
      "    # Collect the information\n",
      "    def get_pid_host_cpucount():\n",
      "        import os\n",
      "        import socket\n",
      "        import multiprocessing as mp\n",
      "        return socket.gethostname(), mp.cpu_count(), os.getpid()\n",
      "    \n",
      "    engine_info = direct_view.apply(get_pid_host_cpucount).get_dict()\n",
      "\n",
      "    # Group info by hosts in nested dicts\n",
      "    info_by_host = defaultdict(dict)\n",
      "    for engine_id, info in engine_info.items():\n",
      "        host_info = info_by_host[info[0]]\n",
      "        host_info['cpu_count'] = info[1]\n",
      "        host_info.setdefault('engine_pids', {})[engine_id] = info[2]\n",
      "    \n",
      "    return info_by_host\n",
      "\n",
      "def assign_engines(direct_view, verbose=False):\n",
      "    \"\"\"Assign each engine to a specific CPU on it hosts.\"\"\"\n",
      "\n",
      "    def taskset(engine_pids, verbose=False):\n",
      "        import os\n",
      "        import socket\n",
      "        import multiprocessing as mp\n",
      "        cpu_count = mp.cpu_count()\n",
      "        for i, pid in enumerate(engine_pids):\n",
      "            cpu_id = i % cpu_count\n",
      "            cmd = \"taskset -p -c {} {}\".format(cpu_id, pid)\n",
      "            if os.system(cmd) != 0:\n",
      "                err_msg = \"Command '{}' failed on host: {}\"\n",
      "                raise RuntimeError(err_msg.format(\n",
      "                    cmd, socket.gethostname()))\n",
      "    \n",
      "    tasks = []\n",
      "    host_infos = collect_host_infos(direct_view)\n",
      "    for host_info in host_infos.values():\n",
      "        client = direct_view.client\n",
      "        pid_map = host_info['engine_pids']\n",
      "        engine_pids = list(sorted(pid_map.values()))    \n",
      "        ctl_engine_view = client[pid_map.keys()[0]]\n",
      "        tasks.append(ctl_engine_view.apply(taskset, engine_pids, verbose))\n",
      "        \n",
      "    # Wait for completion\n",
      "    [t.r for t in tasks]\n",
      "\n",
      "\n",
      "assign_engines(client[:], verbose=True)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 9
    },
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Initialization of the contiguous mmap data files from the svmlight formatted chunks"
     ]
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "We use IPython.parallel to parse the svmlight formatted chunks and store the resulting data into a single precision float, C contiguous array suitable for training ensembles of trees."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def parse_chunk(task_info):\n",
      "    from sklearn.datasets import load_svmlight_file\n",
      "    from os.path import join\n",
      "    chunk_index, chunk_filename, chunk_size = task_info[:3]\n",
      "    data_filepath, labels_filepath = task_info[3:]\n",
      "    data, labels = get_mmap_data(data_filepath,\n",
      "                                 labels_filepath,\n",
      "                                 mode='r+')\n",
      "    X, y = load_svmlight_file(chunk_filename, n_features=28 ** 2)\n",
      "    region = slice(chunk_index * chunk_size,\n",
      "                   (chunk_index + 1) * chunk_size)\n",
      "    data[region, :] = X.toarray() / 255.\n",
      "    labels[region] = y\n",
      "\n",
      "    \n",
      "task_infos = [(chunk_index,\n",
      "               join(DATA_FOLDER, chunk_filename),\n",
      "               CHUNK_SIZE,\n",
      "               MNIST8M_MMAP_DATA_FILEPATH,\n",
      "               MNIST8M_MMAP_LABELS_FILEPATH)\n",
      "              for chunk_index, chunk_filename in enumerate(chunk_filenames)]\n",
      "\n",
      "result = lb_view.map(parse_chunk, task_infos)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 10
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "result.get()\n",
      "%time data.max(), labels.max()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "CPU times: user 7.03 s, sys: 4.14 s, total: 11.2 s\n",
        "Wall time: 11.2 s\n"
       ]
      },
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 11,
       "text": [
        "(memmap(1.0, dtype=float32), memmap(9, dtype=int32))"
       ]
      }
     ],
     "prompt_number": 11
    },
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Training ensembles of trees in parallel"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "%time data.max(), labels.max()"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "CPU times: user 7.63 s, sys: 3.62 s, total: 11.3 s\n",
        "Wall time: 11.3 s\n"
       ]
      },
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 4,
       "text": [
        "(memmap(1.0, dtype=float32), memmap(9, dtype=int32))"
       ]
      }
     ],
     "prompt_number": 4
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "data_train, labels_train = data[:-100000], labels[:-100000]\n",
      "data_test, labels_test = data[-100000:], labels[-100000:]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 5
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "data_train.shape"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 6,
       "text": [
        "(8000000, 784)"
       ]
      }
     ],
     "prompt_number": 6
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "data_test.shape"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 7,
       "text": [
        "(100000, 784)"
       ]
      }
     ],
     "prompt_number": 7
    },
    {
     "cell_type": "markdown",
     "metadata": {},
     "source": [
      "Backup the dataset in binary format."
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import bloscpack as bp\n",
      "\n",
      "for i in range(81):\n",
      "    chunk_location = slice(i * int(1e5), (i + 1) * int(1e5))\n",
      "    filename = join(DATA_FOLDER, \"mnist8m-chunk-%02d.blp\" % i)\n",
      "    bp.pack_ndarray_file(data[chunk_location], filename)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "filename = join(DATA_FOLDER, \"mnist8m-labels.blp\")\n",
      "bp.pack_ndarray_file(labels, filename)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 21
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "def train_trees(train_trees_params):\n",
      "    from sklearn.ensemble import ExtraTreesClassifier\n",
      "    \n",
      "    n_estimators, random_state, n_samples = train_trees_params[:3]\n",
      "    data_file, labels_file  = train_trees_params[3:]\n",
      "    data, labels = get_mmap_data(data_file, labels_file, mode='r')\n",
      "    data_train, labels_train = data[:n_samples], labels[:n_samples]\n",
      "    ensemble = ExtraTreesClassifier(n_estimators=n_estimators,\n",
      "                                    random_state=random_state)\n",
      "    ensemble.fit(data_train, labels_train)\n",
      "    return ensemble\n",
      "\n",
      "train_result = lb_view.map(\n",
      "    train_trees,\n",
      "    [(1, i, int(8e6),\n",
      "      MNIST8M_MMAP_DATA_FILEPATH,\n",
      "      MNIST8M_MMAP_LABELS_FILEPATH) for i in range(16)])"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 17
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "train_result.ready(), train_result.progress, len(train_result), train_result.elapsed"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 79,
       "text": [
        "(True, 16, 16, 2798.811068)"
       ]
      }
     ],
     "prompt_number": 79
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "all_ensembles = train_result.get()\n",
      "first_clf = all_ensembles[0]\n",
      "first_clf.n_classes_"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 81,
       "text": [
        "10"
       ]
      }
     ],
     "prompt_number": 81
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "individual_scores = [clf.score(data_test, labels_test) for clf in all_ensembles]\n",
      "np.mean(individual_scores), np.std(individual_scores)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 82,
       "text": [
        "(0.97396062500000002, 0.0012030714273786978)"
       ]
      }
     ],
     "prompt_number": 82
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from copy import copy\n",
      "\n",
      "final_clf = copy(first_clf)\n",
      "for clf in all_ensembles[1:]:\n",
      "    final_clf.estimators_ += clf.estimators_\n",
      "final_clf.n_estimators = len(final_clf.estimators_)\n",
      "    \n",
      "%time final_clf_score = final_clf.score(data_test, labels_test)\n",
      "print(final_clf_score)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "CPU times: user 2.56 s, sys: 0 ns, total: 2.56 s\n",
        "Wall time: 1.8 s\n",
        "0.99998\n"
       ]
      }
     ],
     "prompt_number": 83
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "!rm -f $DATA_FOLDER/model.*"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 84
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.externals import joblib\n",
      "\n",
      "model_files = joblib.dump(final_clf, DATA_FOLDER + '/model.pkl')"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 85
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import os\n",
      "\n",
      "model_size = sum([os.stat(f).st_size for f in model_files])\n",
      "print(\"Total model size: {}MB\".format(int(model_size / 1e6)))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "Total model size: 1236MB\n"
       ]
      }
     ],
     "prompt_number": 86
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "model_files = joblib.dump(final_clf, DATA_FOLDER + '/model.pkl.gz', compress=1)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 87
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "model_size = sum([os.stat(f).st_size for f in model_files])\n",
      "print(\"Total model size: {}MB\".format(int(model_size / 1e6)))"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "Total model size: 159MB\n"
       ]
      }
     ],
     "prompt_number": 88
    },
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Single process training and evaluation (sanity check)"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.ensemble import ExtraTreesClassifier\n",
      "\n",
      "ensemble = ExtraTreesClassifier(n_estimators=1, random_state=0)\n",
      "%time ensemble.fit(data_train, labels_train)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "%time ensemble.score(data_test, labels_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.externals import joblib\n",
      "\n",
      "joblib.dump(ensemble, DATA_FOLDER + '/model.pkl')\n",
      "!ls -lh $DATA_FOLDER/model.*.npy"
     ],
     "language": "python",
     "metadata": {},
     "outputs": []
    },
    {
     "cell_type": "heading",
     "level": 2,
     "metadata": {},
     "source": [
      "Original MNIST data"
     ]
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.datasets import fetch_mldata\n",
      "\n",
      "mnist = fetch_mldata('MNIST original')\n",
      "\n",
      "mnist_data = np.ascontiguousarray(mnist.data, dtype=np.float32) / 255.\n",
      "\n",
      "X_train, y_train = mnist_data[:-10000], mnist.target[:-10000]\n",
      "X_test, y_test = mnist_data[-10000:], mnist.target[-10000:]\n",
      "del mnist_data\n",
      "del mnist"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 3
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "X_train.shape, X_test.shape"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 8,
       "text": [
        "((60000, 784), (10000, 784))"
       ]
      }
     ],
     "prompt_number": 8
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "ls -lh $DATA_FOLDER/*.gz"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "-rw-r--r--  1 ogrisel  staff   152M Aug 29 16:19 /Users/ogrisel/data/mnist8m/model-mnist8m-16.pkl.gz\r\n"
       ]
      }
     ],
     "prompt_number": 8
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.externals import joblib\n",
      "\n",
      "et_mnist8m_16 = joblib.load(DATA_FOLDER + '/model-mnist8m-16.pkl.gz')"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 21
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "et_mnist8m_16.score(X_test, y_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 10,
       "text": [
        "0.98240000000000005"
       ]
      }
     ],
     "prompt_number": 10
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "all_n_estimators = [1, 2, 4, 8, 16]\n",
      "et16_score_path = [m.score(X_test, y_test)\n",
      "                   for m in model_path(et_mnist8m_16, all_n_estimators)]\n",
      "et16_score_path"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 14,
       "text": [
        "[0.90190000000000003,\n",
        " 0.91069999999999995,\n",
        " 0.95950000000000002,\n",
        " 0.97460000000000002,\n",
        " 0.98240000000000005]"
       ]
      }
     ],
     "prompt_number": 14
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "x = np.array(all_n_estimators)\n",
      "plt.plot(x, np.array(et16_score_path))\n",
      "plt.ylim(None, 1)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 19,
       "text": [
        "(0.90000000000000002, 1)"
       ]
      },
      {
       "metadata": {},
       "output_type": "display_data",
       "png": "iVBORw0KGgoAAAANSUhEUgAAAXsAAAD9CAYAAABdoNd6AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAH+pJREFUeJzt3X1U1HWix/H3IJgVauX6kAwuJMSDEGIYPeiCZZfNW1yf\naildUbHltHVW29bV/uis7t5FXXfvydY9e8gFDfOS7eZGa8oWt0hbNbajW10hJZdRwMe8iCnawPi7\nf/xWEoHhwYHfDPN5ndMRht+Mn9mzfPye7/c735/NMAwDERHp0wKsDiAiIj1PZS8i4gdU9iIifkBl\nLyLiB1T2IiJ+QGUvIuIHOiz7+fPnM3z4cOLj49u95kc/+hGRkZEkJCSwb9++5seLi4uJjo4mMjKS\nVatWeSaxiIh0WYdlP2/ePIqLi9v9+bZt2/jiiy+orKzk5Zdf5qmnngLA5XLxzDPPUFxcTHl5OYWF\nhVRUVHguuYiIdFqHZT9x4kRuvvnmdn/+1ltvkZmZCUBycjJnzpzh+PHjlJWVERERQVhYGEFBQWRk\nZFBUVOS55CIi0mnXPGdfW1tLaGho8/d2u53a2lqOHj3a5uMiItL7Aj3xItdy4oLNZvNEBBERv9OV\n7r3mkX1ISAjV1dXN39fU1GC321s9Xl1djd1ub/M1DMPw+v9+9rOfWZ5BOZXTl3P6QkZfytlV11z2\n6enpFBQUALBnzx5uuukmhg8fTlJSEpWVlTgcDpxOJ5s3byY9Pf1a/zoREemGDqdxHn/8cT744AO+\n/PJLQkNDWb58OY2NjQBkZ2czZcoUtm3bRkREBDfeeCPr1683XzgwkLVr15KWlobL5SIrK4uYmJie\nfTciItKmDsu+sLCwwxdZu3Ztm48/9NBDPPTQQ11P5YVSU1OtjtApyulZyuk5vpARfCdnV9mM7kz+\neDKAzdat+ScREX/W1e7UcQkiIn5AZS8i4gdU9iIifkBlLyLiB1T2IiJ+QGUvIuIHVPYiIn5AZS8i\n4gdU9iIifkBlLyLiB1T2IiJ+QGUvIuIHVPYiIn5AZS8i4gdU9iIifkBlLyLiB1T2IiJ+QGUvIuIH\nVPYiIn5AZS8i4gdU9iIifkBlLyLiB1T2IiJ+QGUvIuIHVPYiIn5AZS8i4gdU9iIifkBlLyLiB1T2\nIiJ+QGUvIuIHVPYiIn5AZS8i4gcCrQ4gIiLta2iAo0ehtrbln12lshcRsYDLBSdOfFPgV5b5lV9f\nuAAjR0JIiPnf5a+7ymYYhuH5t9GFADYbFkcQEfEYw4D6+taj8au/PnUKbrmlZYG39fUtt4DN1vrv\n6Wp3dlj2xcXFLFq0CJfLxYIFC1iyZEmLn9fV1TF//nz++c9/MmDAAPLz8xkzZgwAK1as4NVXXyUg\nIID4+HjWr1/Pddddd02BRUSs8vXXcOxY6ymVq8u8X7+Wxd1WmY8YAUFB3c/i0bJ3uVxERUVRUlJC\nSEgI48ePp7CwkJiYmOZrFi9ezKBBg3jhhRc4cOAATz/9NCUlJTgcDu6//34qKiq47rrr+N73vseU\nKVPIzMy8psAiIp526RJ8+WXbxX3lY/X1Zkm3Nwq//OfAgT2fuavd6XbOvqysjIiICMLCwgDIyMig\nqKioRdlXVFSwdOlSAKKionA4HJw6dYpBgwYRFBREQ0MD/fr1o6GhgZDuTDSJiFyD8+fdz4kfPWqO\n1gcObF3gSUktHxs6FAJ8dA+j27Kvra0lNDS0+Xu73c5HH33U4pqEhAS2bNnChAkTKCsr4/Dhw9TU\n1JCYmMhzzz3HqFGjuP7660lLS2Py5Mlt/j3Lli1r/jo1NZXU1NTuvyMR8QtNTXD8eMcLnI2NrUfe\nYWFw333fPHbrrTBggNXvyL3S0lJKS0u7/Xy3ZW9ra1XgKkuXLmXhwoUkJiYSHx9PYmIi/fr149Ch\nQ7z44os4HA4GDx7Mo48+yqZNm5g1a1ar17i
y7EXEvxkG1NV1vMB5+jR861utp1FSU1s+dtNNbS9w\n+pqrB8LLly/v0vPdln1ISAjV1dXN31dXV2O321tcM3DgQPLz85u/Dw8P57bbbuPtt9/m3nvvZciQ\nIQBMnz6dXbt2tVn2IuIfLl5se8/4lWV+9Cj07996SiUuDtLSvnls+HAI1ObxTnP7P1VSUhKVlZU4\nHA5GjhzJ5s2bKSwsbHFNfX09119/Pf3792fdunWkpKQQHBxMVFQUv/jFL7hw4QIDBgygpKSEu+66\nq0ffjIhY49IlOHmy4wXOc+fMsr56l8q4cS0fu/FGq99R3+O27AMDA1m7di1paWm4XC6ysrKIiYkh\nNzcXgOzsbMrLy5k7dy42m424uDjy8vIAGDt2LHPmzCEpKYmAgADGjRvHD37wg55/RyLiUV991fEC\n5/Hj5nTJ1VMqd9/d8rFvfatvTKn4In2oSsRPNTaau1A6WuA0DPfbDC8vcPbvb/U78i8e/1BVT1PZ\ni3iWYZiLlx0tcNbVwbBhHe8ZHzRIo3FvpLIX6cPaOxTr6gXOG25ofxR++ethw8xPeopvUtmL+CB3\nh2Jd+djFiy0XONsq85Ej4frrrX5H0tNU9iJexDDg7Fn3c+JXH4rlbkqlvUOxxP+o7EV6SXcPxWrr\n62s9FEv8j8pexEMOHYLPP2+/zM+eNUva3cmGvXUolvgflb3INbh4Ed54A3Jz4eDB1h/2ubLMfflQ\nLPF9KnuRbjhwAF5+GQoKIDERsrMhPV1TK+K9PHrEsUhf9vXX8Oc/m6P48nKYNw/27IHRo61OJuJ5\nKnvxO198AevWwYYN5uFaTz0FU6fqE6DSt6nsxS80NkJRkTmK/+QTyMyEnTvh9tutTibSO1T20qdV\nVZmj+PXrzWLPzobp073/RhUinqaylz6nqQn+8hdzFP/xx/D978N778EVd9MU8Tsqe+kzjhwxR/H5\n+RAebo7i//xnHR0gAip78XFNTbB9uzmK370bZs2Cv/7VXHgVkW+o7MUn1dRAXh784Q/mh5yys+H1\n183THkWkNZW9+AyXyxy15+aaO2kyMmDrVkhIsDqZiPdT2YvXO3bsm1H80KHmKH7TJggOtjqZiO9Q\n2YtXunQJ3n3XHMW//z489hhs2WKeVSMiXaeyF69y4oS5J37dOhg82BzFv/KKTo4UuVYqe7HcpUvm\n6D031xzNT58OhYUwfrxu1CHiKTr1Uixz6pR5Ps3LL5t74bOzYfZsc0QvIu7p1EvxaoYBH3xgjuK3\nbzcPICsogLvv1ihepCdpZC+94v/+z5x7z801b9OXnW0eY3DzzVYnE/FNGtmL1zAM+NvfzIL/y1/g\nkUfM7ZP33adRvEhv08hePK6uDjZuNEu+qckcxWdmwpAhVicT6Ts0shdLGIZ5l6fcXHjzTXjoIfjd\n7yAlRaN4EW+gkb1ck/p6ePVVs+QvXIAf/ADmzjU/6SoiPUc3HJceZxjmOfG5ufDGG/Dgg+ZUzaRJ\nEBBgdToR/6BpHOkxX30F//3fZsmfOWOO4j//HIYPtzqZiHREI3vp0N69ZsG//ro5es/ONkfzGsWL\nWEcje/GI8+fNIwtyc+HkSXjySdi/H0aOtDqZiHSHRvbSwiefmAX/2mswcaI5ik9LMz8IJSLeQyN7\n6bKGBnOKJjfXvAPUggXw6adgt1udTEQ8RSN7P7Z/v1nwmzaZZ9NkZ8OUKRCoIYCI1+tqd3a4xFZc\nXEx0dDSRkZGsWrWq1c/r6uqYNm0aCQkJJCcns3///uafnTlzhpkzZxITE0NsbCx79uzpdDDpGRcu\nmJ9unTDBXGQdPNhcgH37bUhPV9GL9FVuR/Yul4uoqChKSkoICQlh/PjxFBYWEhMT03zN4sWLGTRo\nEC+88AIHDhzg6aefpqSkBIDMzExSUlKYP38+TU1NnD9/nsFXnV+rkX3v+PxzcxS/cSMkJZmj+Icf\nhqAgq5O
JSHd4dGRfVlZGREQEYWFhBAUFkZGRQVFRUYtrKioqmDRpEgBRUVE4HA5OnTpFfX09O3fu\nZP78+QAEBga2KnrpWV9/be6LT0mB1FTzzPi//x2Ki2HaNBW9iD9xW/a1tbWEhoY2f2+326mtrW1x\nTUJCAlu2bAHMfxwOHz5MTU0NVVVVDB06lHnz5jFu3DiefPJJGhoaeuAtyNUqK2HxYggNhfx8eOYZ\nOHIEcnIgPNzqdCJiBbcztLZOnGC1dOlSFi5cSGJiIvHx8SQmJtKvXz+cTid79+5l7dq1jB8/nkWL\nFrFy5Up+/vOft3qNZcuWNX+dmppKampql9+ImDtonn0WPvvMPJ9m1y6IiLA6lYh4QmlpKaWlpd1+\nvts5+z179rBs2TKKi4sBWLFiBQEBASxZsqTdFwwPD+ezzz7j3Llz3HPPPVRVVQHw4YcfsnLlSrZu\n3doygObsPebf/x2Sk2HJErjuOqvTiEhP8uicfVJSEpWVlTgcDpxOJ5s3byY9Pb3FNfX19TidTgDW\nrVtHSkoKwcHBjBgxgtDQUA4ePAhASUkJY8aM6er7kU6qrYXdu+EnP1HRi0hrbqdxAgMDWbt2LWlp\nabhcLrKysoiJiSE3NxeA7OxsysvLmTt3Ljabjbi4OPLy8pqf/9vf/pZZs2bhdDoZPXo069ev79l3\n48deeQUefRRuuMHqJCLijfShqj7AMCAy0vxwVHKy1WlEpDd4/ENV4v127oQBA+Cuu6xOIiLeSmXf\nB+Tlwfz5uv2fiLRP0zg+7uxZGDUKDh6EYcOsTiMivUXTOH5m82Z44AEVvYi4p7L3cZencERE3FHZ\n+7D9+6G62ry5iIiIOyp7H7Z+PWRm6lhiEemYFmh9lNNpHnT24YfmHnsR8S9aoPUTb78N0dEqehHp\nHJW9j9LCrIh0haZxfNDRozBmjHlz8BtvtDqNiFhB0zh+oKDAPPRMRS8inaV9HD7GMMy7TxUUWJ1E\nRHyJRvY+5sMPza2WOt1SRLpCZe9j8vMhK0uHnolI12iB1oecPQvf/jYcOKCzcET8nRZo+7DXX4dJ\nk1T0ItJ1Knsfkp+vvfUi0j0qex9RUQEOB3z3u1YnERFfpLL3Efn5OvRMRLpPC7Q+oLER7HbzXrO3\n3251GhHxBlqg7YPeftsseRW9iHSXyt4HXN5bLyLSXZrG8XLHjkFsrHlHquBgq9OIiLfQNE4fU1AA\nM2ao6EXk2mhvhxe7fOjZhg1WJxERX6eRvRf7298gIADuvtvqJCLi61T2XkyHnomIp2iB1kt99RWM\nGgWffw7Dh1udRkS8jRZo+4jXX4eUFBW9iHiGyt5LaW+9iHiSpnG80Oefm0cZV1frLBwRaZumcfqA\n/HyYM0dFLyKeo5G9l2lshNBQ+OADiIqyOo2IeCuN7H3ctm0QEaGiFxHPUtl7Gd2NSkR6QodlX1xc\nTHR0NJGRkaxatarVz+vq6pg2bRoJCQkkJyezf//+Fj93uVwkJibyyCOPeC51H3X8OOzYAY89ZnUS\nEelr3Ja9y+XimWeeobi4mPLycgoLC6moqGhxTU5ODuPGjeOTTz6hoKCAhQsXtvj5mjVriI2NxaaP\ngXaooACmT9ehZyLieW7LvqysjIiICMLCwggKCiIjI4OioqIW11RUVDBp0iQAoqKicDgcnDp1CoCa\nmhq2bdvGggULtAjbgcuHnmlvvYj0BLdlX1tbS2hoaPP3drud2traFtckJCSwZcsWwPzH4fDhw9TU\n1ADw7LPPsnr1agICtDTQkd27zT/vucfaHCLSN7ndyd2ZqZelS5eycOFCEhMTiY+PJzExkYCAALZu\n3cqwYcNITEyktLTU7WssW7as+evU1FRSU1M7k71PycszF2Y12yUibSktLe2wS91xu89+z549LFu2\njOLiYgBWrFhBQEAAS5YsafcFw8PD+fTTT1mxYgUbN24kMDCQixcvcvbsW
WbMmEFBQUHLANpnz7lz\n5t76igoYMcLqNCLiCzy6zz4pKYnKykocDgdOp5PNmzeTnp7e4pr6+nqcTicA69atIyUlhYEDB5KT\nk0N1dTVVVVW89tpr3H///a2KXkx//CN85zsqehHpOW6ncQIDA1m7di1paWm4XC6ysrKIiYkhNzcX\ngOzsbMrLy5k7dy42m424uDjy8vLafC3txmlfXh4sXmx1ChHpy3RcgsUOHDCPMq6uhqAgq9OIiK/Q\ncQk+Zv1689AzFb2I9CSN7C3U1GQuzL7/PkRHW51GRHyJRvY+ZPt2uO02Fb2I9DyVvYUu760XEelp\nmsaxyPHjEBMDR47AwIFWpxERX6NpHB/x6qswbZqKXkR6h8reAoahKRwR6V0qewvs2QOXLsF991md\nRET8hcreApfvRqUPFYtIb9ECbS+7fOhZeTnceqvVaUTEV2mB1sv96U8wcaKKXkR6l8q+l+mG4iJi\nBU3j9KKDB82jjHXomYhcK03jeLH16+H731fRi0jv08i+lzQ1wahR8D//Y35yVkTkWmhk76WKiyEs\nTEUvItZQ2fcSLcyKiJU0jdMLTpwwjzE+fBgGDbI6jYj0BZrG8UKvvgpTp6roRcQ6KvseZhiawhER\n66nse9hHH0FjI0yYYHUSEfFnKvsepkPPRMQbaIG2B50/bx569r//CyNHWp1GRPoSLdB6kT/9Ce69\nV0UvItZT2feg/HzIyrI6hYiIpnF6TGWluShbXQ39+1udRkT6Gk3jeInLh56p6EXEG2hk3wOamuDb\n34Z33oExY6xOIyJ9kUb2XuCvfzV34ajoRcRbqOx7gBZmRcTbaBrHw06dgshIOHJEZ+GISM/RNI7F\nNm6E//gPFb2IeBeVvQcZBuTlaQpHRLyPyt6D/v53cDph4kSrk4iItKSy96C8PJg3T4eeiYj30QKt\nhzQ0gN0On30GISFWpxGRvq5HFmiLi4uJjo4mMjKSVatWtfp5XV0d06ZNIyEhgeTkZPbv3w9AdXU1\nkyZNYsyYMcTFxfHSSy91OpiveeMNuOceFb2IeKcOR/Yul4uoqChKSkoICQlh/PjxFBYWEhMT03zN\n4sWLGTRoEC+88AIHDhzg6aefpqSkhOPHj3P8+HHGjh3LuXPnuPPOO3nzzTdbPLevjOxTU+FHP4Lp\n061OIiL+wOMj+7KyMiIiIggLCyMoKIiMjAyKiopaXFNRUcGkSZMAiIqKwuFwcOrUKUaMGMHYsWMB\nCA4OJiYmhqNHj3bl/fiEL76A8nJ4+GGrk4iItC2wowtqa2sJDQ1t/t5ut/PRRx+1uCYhIYEtW7Yw\nYcIEysrKOHz4MDU1NQwdOrT5GofDwb59+0hOTm71dyxbtqz569TUVFJTU7vxVqyzYQPMnq1Dz0Sk\n55SWllJaWtrt53dY9rZObC1ZunQpCxcuJDExkfj4eBITE+nXr1/zz8+dO8fMmTNZs2YNwcHBrZ5/\nZdn7GpfLLPviYquTiEhfdvVAePny5V16fodlHxISQnV1dfP31dXV2O32FtcMHDiQ/Pz85u/Dw8O5\n7bbbAGhsbGTGjBnMnj2bqVOndimcL3jnHXNRNi7O6iQiIu3rcM4+KSmJyspKHA4HTqeTzZs3k56e\n3uKa+vp6nE4nAOvWrSMlJYXg4GAMwyArK4vY2FgWLVrUM+/AYnl55g3FRUS8Waf22W/fvp1Fixbh\ncrnIysri+eefJzc3F4Ds7Gx2797N3LlzsdlsxMXFkZeXx+DBg/nwww/5zne+wx133NE8HbRixQq+\n+93vfhPAh3fjXD707PBhGDzY6jQi4k+62p36UNU1ePFF2LsXCgqsTiIi/kanXvaSy4eeaQpHRHyB\nyr6bPv4YLlyAlBSrk4iIdExl3035+Tr0TER8h+bsu+HyoWeffmr+KSLS2zRn3wu2bIG771bRi4jv\nUNl3Q36+FmZFxLdoGqeLDh0yjzKuq
dFZOCJiHU3j9LANG2DWLBW9iPgWjey7wOWCsDDYtg3i461O\nIyL+TCP7HvTuu3DrrSp6EfE9Kvsu0MKsiPgqTeN00pdfQkQEOBxw001WpxERf6dpnB6yaRM88oiK\nXkR8k8q+E3TomYj4OpV9J+zYAefP69AzEfFdHd6W0N/t3AmPPmouzgbon0YR8VGqLzeKi2HGDCgs\nhIcftjqNiEj3qezb8cc/QmYmFBXBAw9YnUZE5Nqo7NuQlweLFpkforrnHqvTiIhcO83ZX+W//gte\neglKS82biYuI9AUq+38xDPjZz+D1181F2dBQqxOJiHiOyh64dAmefdbcYrljBwwbZnUiERHP8vuy\nb2qCBQvgiy/g/ff1CVkR6Zv8uuy//hoef9y8p+w778ANN1idSESkZ/jtbpzz582zbgICzO2VKnoR\n6cv8suzPnIF/+zfzhuGvvQbXXWd1IhGRnuV3ZX/iBKSmwl13wR/+AIF+PZElIv7Cr8r+yBGYOBGm\nTTP30+usGxHxF35TdwcOmEX/wx+a++ltNqsTiYj0Hr+YxPjHP2DKFPjP/9SZ9CLin/p82e/aZU7b\n/O53MHOm1WlERKzRp8v+3Xdh1izYuBHS0qxOIyJinT47Z79lC8yebf6pohcRf9cny37DBnj6afPm\nIxMmWJ1GRMR6NsMwDEsD2Gx4MkJZmTk3/847EB3tsZcVEfEqXe3ODkf2xcXFREdHExkZyapVq1r9\nvK6ujmnTppGQkEBycjL79+/v9HN7wvjx5u4bTxd9aWmpZ1+whyinZymn5/hCRvCdnF3ltuxdLhfP\nPPMMxcXFlJeXU1hYSEVFRYtrcnJyGDduHJ988gkFBQUsXLiw08/tCTYb3HKL51/XV/4PoJyepZye\n4wsZwXdydpXbsi8rKyMiIoKwsDCCgoLIyMigqKioxTUVFRVMmjQJgKioKBwOBydPnuzUc0VEpHe4\nLfva2lpCr7hlk91up7a2tsU1CQkJbNmyBTD/cTh8+DA1NTWdeq6IiPQOt/vsbZ04U2Dp0qUsXLiQ\nxMRE4uPjSUxMpF+/fp16blf+Hm+wfPlyqyN0inJ6lnJ6ji9kBN/J2RVuyz4kJITq6urm76urq7Hb\n7S2uGThwIPn5+c3fh4eHM3r0aC5cuNDhcwGP7sQREZG2uZ3GSUpKorKyEofDgdPpZPPmzaSnp7e4\npr6+HqfTCcC6detISUkhODi4U88VEZHe4XZkHxgYyNq1a0lLS8PlcpGVlUVMTAy5ubkAZGdnU15e\nzty5c7HZbMTFxZGXl+f2uSIiYgHDQtu3bzeioqKMiIgIY+XKlVZGadeRI0eM1NRUIzY21hgzZoyx\nZs0aqyO1q6mpyRg7dqzx8MMPWx2lXXV1dcaMGTOM6OhoIyYmxti9e7fVkdqUk5NjxMbGGnFxccbj\njz9uXLx40epIhmEYxrx584xhw4YZcXFxzY+dPn3amDx5shEZGWk8+OCDRl1dnYUJTW3l/MlPfmJE\nR0cbd9xxhzFt2jTjzJkzFiY0tZXzsl//+teGzWYzTp8+bUGyltrL+dJLLxnR0dHGmDFjjJ/+9Kdu\nX8Oysm9qajJGjx5tVFVVGU6n00hISDDKy8utitOuY8eOGfv27TMMwzC++uor4/bbb/fKnIZhGL/5\nzW+MJ554wnjkkUesjtKuOXPmGHl5eYZhGEZjY6NX/MJfraqqyggPD28u+Mcee8zYsGGDxalMO3bs\nMPbu3dvil37x4sXGqlWrDMMwjJUrVxpLliyxKl6ztnK+8847hsvlMgzDMJYsWeK1OQ3DHOSlpaUZ\nYWFhXlH2beV87733jMmTJxtOp9MwDMM4efKk29ew7GwcX9mHP2LECMaOHQtAcHAwMTExHD161OJU\nrdXU1LBt2zYWLFjgtYve9fX17Ny5k/n/uqlAYGAggwcPtjhVa4MGDSIoKIiGhgaamppoaGggJCTE\n6
lgATJw4kZtvvrnFY2+99RaZmZkAZGZm8uabb1oRrYW2cj744IME/Ov2cMnJydTU1FgRrYW2cgL8\n+Mc/5le/+pUFidrWVs7f//73PP/88wQFBQEwdOhQt69hWdn74j58h8PBvn37SE5OtjpKK88++yyr\nV69u/mXyRlVVVQwdOpR58+Yxbtw4nnzySRoaGqyO1cott9zCc889x6hRoxg5ciQ33XQTkydPtjpW\nu06cOMHw4cMBGD58OCdOnLA4Ucfy8/OZMmWK1THaVFRUhN1u54477rA6iluVlZXs2LGDu+++m9TU\nVD7++GO311vWDL6yt/6yc+fOMXPmTNasWUNwcLDVcVrYunUrw4YNIzEx0WtH9QBNTU3s3buXH/7w\nh+zdu5cbb7yRlStXWh2rlUOHDvHiiy/icDg4evQo586dY9OmTVbH6hSbzeb1v1u//OUv6d+/P088\n8YTVUVppaGggJyenxT57b/2dampqoq6ujj179rB69Woee+wxt9dbVvad2cPvLRobG5kxYwazZ89m\n6tSpVsdpZdeuXbz11luEh4fz+OOP89577zFnzhyrY7Vit9ux2+2MHz8egJkzZ7J3716LU7X28ccf\nc++99zJkyBACAwOZPn06u3btsjpWu4YPH87x48cBOHbsGMOGDbM4Ufs2bNjAtm3bvPYfz0OHDuFw\nOEhISCA8PJyamhruvPNOTp48aXW0Vux2O9OnTwdg/PjxBAQEcPr06Xavt6zsfWUfvmEYZGVlERsb\ny6JFi6yO06acnByqq6upqqritdde4/7776egoMDqWK2MGDGC0NBQDh48CEBJSQljxoyxOFVr0dHR\n7NmzhwsXLmAYBiUlJcTGxlodq13p6em88sorALzyyiteOSAB8xTc1atXU1RUxIABA6yO06b4+HhO\nnDhBVVUVVVVV2O129u7d65X/gE6dOpX33nsPgIMHD+J0OhkyZEj7T+ip1ePO2LZtm3H77bcbo0eP\nNnJycqyM0q6dO3caNpvNSEhIMMaOHWuMHTvW2L59u9Wx2lVaWurVu3H+8Y9/GElJSV61/a4tq1at\nat56OWfOnOYdD1bLyMgwbr31ViMoKMiw2+1Gfn6+cfr0aeOBBx7wqq2XV+fMy8szIiIijFGjRjX/\nHj311FNWx2zO2b9//+b/Pa8UHh7uFbtx2srpdDqN2bNnG3Fxcca4ceOM999/3+1rWH7zEhER6Xne\nu3VDREQ8RmUvIuIHVPYiIn5AZS8i4gdU9iIifkBlLyLiB/4f2TwW+6/CGsUAAAAASUVORK5CYII=\n",
       "text": [
        "<matplotlib.figure.Figure at 0x10785ba10>"
       ]
      }
     ],
     "prompt_number": 19
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.ensemble import RandomForestClassifier\n",
      "\n",
      "# Random forest of 100 trees, built with 2 parallel jobs.\n",
      "# NOTE(review): no random_state is passed, so the score below may vary\n",
      "# slightly from one run to the next.\n",
      "rf = RandomForestClassifier(n_estimators=100, n_jobs=2)\n",
      "# %time prints the CPU and wall-clock time of each statement.\n",
      "%time rf.fit(X_train, y_train)\n",
      "# score() is the last expression, so its value is shown as the cell output.\n",
      "%time rf.score(X_test, y_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "CPU times: user 702 ms, sys: 1.68 s, total: 2.38 s\n",
        "Wall time: 1min 37s\n"
       ]
      },
      {
       "output_type": "stream",
       "stream": "stdout",
       "text": [
        "CPU times: user 213 ms, sys: 321 ms, total: 533 ms\n",
        "Wall time: 1.72 s\n"
       ]
      },
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 8,
       "text": [
        "0.96830000000000005"
       ]
      }
     ],
     "prompt_number": 8
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from sklearn.ensemble import ExtraTreesClassifier\n",
      "\n",
      "# Extra-trees ensemble of 1000 trees, built with 2 parallel jobs.\n",
      "# The fitted `et` is subsampled further down by model_path().\n",
      "# NOTE(review): no random_state is passed, so results may vary between runs.\n",
      "et = ExtraTreesClassifier(n_estimators=1000, n_jobs=2)\n",
      "# %time prints the CPU and wall-clock time of each statement.\n",
      "%time et.fit(X_train, y_train)\n",
      "%time et.score(X_test, y_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": "*"
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "from copy import copy\n",
      "from random import Random\n",
      "\n",
      "def model_path(model, all_n_estimators, seed=0):\n",
      "    \"\"\"Derive smaller ensembles by subsampling a fitted ensemble's estimators.\n",
      "\n",
      "    A shallow copy of `model` is made for every requested size. When the\n",
      "    size is smaller than the number of fitted estimators, the copy's\n",
      "    `estimators_` list is replaced by a random sample (without\n",
      "    replacement) of the fitted estimators; when it is equal, the copy keeps\n",
      "    the full list. The input `model` itself is never modified.\n",
      "\n",
      "    Parameters\n",
      "    ----------\n",
      "    model : object\n",
      "        Fitted ensemble exposing an `estimators_` list.\n",
      "    all_n_estimators : iterable of int\n",
      "        Ensemble sizes to produce, in the order they should be returned.\n",
      "    seed : int, optional (default=0)\n",
      "        Seed of the private random generator used for the subsampling.\n",
      "\n",
      "    Returns\n",
      "    -------\n",
      "    list\n",
      "        One shallow copy of `model` per requested size, with\n",
      "        `n_estimators` set to that size.\n",
      "\n",
      "    Raises\n",
      "    ------\n",
      "    ValueError\n",
      "        If a requested size exceeds the number of fitted estimators.\n",
      "    \"\"\"\n",
      "    rng = Random(seed)\n",
      "    n_fitted = len(model.estimators_)\n",
      "    models = []\n",
      "    for n_estimators in all_n_estimators:\n",
      "        if n_estimators > n_fitted:\n",
      "            raise ValueError(\n",
      "                \"Cannot subsample %d estimators out of %d\"\n",
      "                % (n_estimators, n_fitted))\n",
      "        # Shallow copy: the fitted estimators are shared with `model`,\n",
      "        # only the list holding them is swapped on the copy.\n",
      "        model_copy = copy(model)\n",
      "        if n_estimators < n_fitted:\n",
      "            model_copy.estimators_ = rng.sample(model.estimators_, n_estimators)\n",
      "        model_copy.n_estimators = n_estimators\n",
      "        models.append(model_copy)\n",
      "    return models"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 13
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Ensemble sizes passed to model_path(); model_path() rejects any size\n",
      "# larger than the number of estimators of the fitted model it is given.\n",
      "all_n_estimators = [10, 30, 100, 300]"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": "*"
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Score every subsampled copy of `et` on (X_test, y_test); %time measures\n",
      "# the whole list comprehension, subsampling included.\n",
      "# score_path[i] corresponds to all_n_estimators[i] (model_path keeps order).\n",
      "%time score_path = [m.score(X_test, y_test) for m in model_path(et, all_n_estimators)]\n",
      "score_path"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": "*"
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Test score as a function of the number of sampled estimators.\n",
      "x = np.array(all_n_estimators)\n",
      "plt.plot(x, np.array(score_path))\n",
      "# Label the figure so it can be read without the surrounding cells.\n",
      "plt.xlabel('n_estimators')\n",
      "plt.ylabel('score on (X_test, y_test)')\n",
      "plt.title('ExtraTrees score vs. number of sampled estimators')\n",
      "# Trailing semicolon hides the (ymin, ymax) tuple that ylim() returns.\n",
      "plt.ylim(0.95, 1);"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": "*"
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "import bloscpack as bp\n",
      "\n",
      "# Load the bloscpack files for MNIST8M chunk 00 and for the labels.\n",
      "mnist8m_data = bp.unpack_ndarray_file(DATA_FOLDER + '/bloscpack/mnist8m-chunk-00.blp')\n",
      "mnist8m_labels = bp.unpack_ndarray_file(DATA_FOLDER + '/bloscpack/mnist8m-labels.blp')\n",
      "\n",
      "# Append the first 40000 loaded rows / labels to the existing training set.\n",
      "# NOTE(review): assumes the first 40000 entries of mnist8m-labels.blp are the\n",
      "# labels of the first 40000 rows of chunk 00 -- confirm.\n",
      "X_train_big = np.vstack([X_train, mnist8m_data[:40000]])\n",
      "y_train_big = np.concatenate([y_train, mnist8m_labels[:40000]])"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": 10
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Check the element type of the stacked training matrix.\n",
      "X_train_big.dtype"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [
      {
       "metadata": {},
       "output_type": "pyout",
       "prompt_number": 11,
       "text": [
        "dtype('float32')"
       ]
      }
     ],
     "prompt_number": 11
    },
    {
     "cell_type": "code",
     "collapsed": false,
     "input": [
      "# Refit a 300-tree extra-trees ensemble on the enlarged training set.\n",
      "# NOTE: this rebinds `et`, replacing the 1000-tree model fitted earlier.\n",
      "et = ExtraTreesClassifier(n_estimators=300, n_jobs=2)\n",
      "# The target must be the label vector y_train_big, not the feature matrix.\n",
      "%time et.fit(X_train_big, y_train_big)\n",
      "%time et.score(X_test, y_test)"
     ],
     "language": "python",
     "metadata": {},
     "outputs": [],
     "prompt_number": "*"
    }
   ],
   "metadata": {}
  }
 ]
}